import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import time


def main():
    #----------------1.加载数据集，并进行数据预处理----------------
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  #均值和标准差，来标准化

    # 50000张训练图片
    # 第一次使用时要将download设置为True才会自动去下载数据集
    train_set = torchvision.datasets.CIFAR10(root='./data',  #下载路径
                                            train=True,      #导入训练集
                                             download=False, 
                                             #download=True, 
                                             transform=transform)   #图像预处理
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,
                                               shuffle=True, 
                                               num_workers=0)  #载入数据的线程数

    # 10000张验证图片
    # 第一次使用时要将download设置为True才会自动去下载数据集
    val_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=False, transform=transform)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,
                                             shuffle=False, num_workers=0)
    val_data_iter = iter(val_loader)            #转化为迭代器，即可用next取出各个元素
    val_image, val_label = next(val_data_iter)
    
    # classes = ('plane', 'car', 'bird', 'cat',
    #            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    #2. 定义训练的模型
    net = LeNet()
    loss_function = nn.CrossEntropyLoss()               # 定义损失函数为交叉熵损失函数 
    optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器（训练参数，学习率）

    for epoch in range(5):  # loop over the dataset multiple times  遍历训练集
        time_start = time.perf_counter()   #计算耗时

        running_loss = 0.0
        for step, data in enumerate(train_loader, start=0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data                       # 获取训练集的图像和标签

            # zero the parameter gradients  梯度清零
            # ------------如果不清除历史梯度，会对历史梯度进行累积
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs)                   # 正向传播
            loss = loss_function(outputs, labels)   # 计算损失（预测，真实）
            loss.backward()                         # 反向传播
            optimizer.step()                        # 优化器更新参数

            # print statistics
            running_loss += loss.item()  # 累积误差：每train_batch_size=36数据，每次训练加一遍
            if step % 500 == 499:    # print every 500 mini-batches 每500步打印一次
                with torch.no_grad():   # 在以下步骤中（验证过程中）不用计算每个节点的损失梯度，防止内存占用
                    outputs = net(val_image)  # [batch, 10]
                    predict_y = torch.max(outputs, dim=1)[1] # 以output中值最大位置对应的索引（标签）作为预测输出
                    accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0) #tensor转为数值：item

                    print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %
                          (epoch + 1, step + 1, running_loss / 500,  # 累积500次打印一次，故进行平均
                          accuracy))
                          
                    time_end = time.perf_counter()      
                    print('%f s' % (time_end - time_start))        # 打印耗时
                    
                    #下一次重置
                    time_start = time_end
                    running_loss = 0.0

    print('Finished Training')

    # 保存训练得到的参数
    save_path = './Lenet.pth'
    torch.save(net.state_dict(), save_path)


if __name__ == '__main__':
    main()
